Connectomics Spike Detection

Now that we have imported the raw data into an HDF5 store, we are ready to look at it. The first step is to inspect the raw signal and find a robust way to detect spikes.

First, we start with the basics again: import some essential libraries and open the HDF5 store.

In [1]:
# Some core imports
import os
import sys
from subprocess import call
import time
import pandas
import numpy as np
import h5py

import theano
import theano.tensor as T
import theano.tensor.nnet as tnn
import matplotlib.pyplot as plt

# These are IPython specific
from IPython.display import display, clear_output
%matplotlib inline
%load_ext cythonmagic

Couldn't import dot_parser, loading of dot files will not be possible.

In [2]:
# Core configuration

datadir = "/data/quick/connectomics"
mirror = ""
# Remember the notebook directory once, even if this cell is re-run after a chdir
if 'nbdir' not in globals():
    nbdir = os.getcwd()

In [27]:
store = h5py.File(datadir + '/store.h5', 'a')  # read/write: a later cell adds datasets to the store

In [4]:
print store.keys()

[u'highcc', u'highcon', u'lowcc', u'normal-1', u'normal-2', u'normal-3', u'normal-4', u'normal-4-lownoise', u'small-1', u'small-2', u'small-3', u'small-4', u'small-5', u'small-6', u'test', u'valid']

In [28]:
net = store['small-1'] # We will use this network for initial testing purposes
print net.keys()

[u'connectionMatrix', u'fluorescence', u'networkPositions', u'signalDistanceMatrix', u'spatialDistanceMatrix', u'spikes']

Matplotlib Animation Code

This code allows us to write videos using Matplotlib and display them in the notebook.

In [5]:
import matplotlib.animation as manimation
from IPython.display import HTML, FileLink
from base64 import b64encode
import matplotlib


def display_video_inline(filename):
    video = open(filename, "rb").read()
    video_encoded = b64encode(video).decode('ascii')
    video_tag = '<video controls alt="Video" src="data:video/x-m4v;base64,{0}"/>'.format(video_encoded)
    return HTML(data=video_tag)

def display_video(filename):
    link = FileLink(os.path.abspath(filename))
    video_tag = '<video controls alt="Video" src="{0}{1}" /><br/><a href="{0}{1}">Source</a>'.format(link.url_prefix,link.path)
    return HTML(data=video_tag)

def display_video_link(filename):
    link = FileLink(os.path.abspath(filename))
    video_tag = '<a href="{0}{1}">{1}</a>'.format(link.url_prefix,link.path)
    return HTML(data=video_tag)
video_dpi = 100
FFMpegWriter = manimation.writers['ffmpeg']
metadata = dict(title='Connectomics Visualization Video', artist='K. Londenberg')

Fluorescence distribution

First, we are going to learn something about how the fluorescence is distributed.

In [6]:
f = np.array(net['fluorescence'].value)
print f.shape
#f = f[:,:]

allf = f.reshape((f.shape[0]*f.shape[1]))
# the histogram of the data
n, bins, patches = plt.hist(allf, 50, normed=1, facecolor='g', alpha=0.75)
plt.title('Fluorescence distribution')

(100, 179498)

This doesn't help much. There seems to be a smooth gradient from signal (spikes) to noise. If the distribution had been bi- or multimodal, things would have been easier. Let's look at the time derivative instead, i.e. how strongly the fluorescence of a single neuron changes from one time step to the next (the plots below loosely label this as autocorrelation).

In [7]:
# Now, we take a look at the distribution of the time-derivative
f = np.array(net['fluorescence'].value)
#print f.shape
f = f[0,1:]-f[0,0:-1]

allf = f #f.reshape((f.shape[0]*f.shape[1]))
print f.shape                                                                   
# the histogram of the data
n, bins, patches = plt.hist(allf, 50, normed=1, facecolor='g', alpha=0.75)
plt.title('Fluorescence autocorrelation')


The step-to-step changes look like those of a near-perfect Gaussian random walk - except, of course, the signal isn't one. We have to investigate further. First we check whether every neuron looks like this.
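One way to put a number on "it isn't" is to compare the skewness and excess kurtosis of the increments with those of pure Gaussian noise. The following is only a sketch on a synthetic trace: the noise level, transient height, decay constant and the helper names `skewness` and `excess_kurtosis` are assumptions made for illustration, not values taken from the competition data.

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical synthetic trace: Gaussian noise plus sparse, decaying transients
n_steps = 20000
noise = rng.normal(0.0, 0.03, n_steps)
kernel = 0.3 * np.exp(-np.arange(200) / 40.0)  # transient shape is an assumption
signal = np.zeros(n_steps)
for t in rng.choice(n_steps - 200, size=40, replace=False):
    signal[t:t + 200] += kernel
trace = signal + noise

# Same quantity as f[0,1:] - f[0,0:-1] in the cell above
diffs = np.diff(trace)
noise_diffs = np.diff(noise)

def skewness(x):
    """Third standardized moment; close to 0 for a Gaussian sample."""
    x = x - x.mean()
    return np.mean(x ** 3) / np.mean(x ** 2) ** 1.5

def excess_kurtosis(x):
    """Fourth standardized moment minus 3; close to 0 for a Gaussian sample."""
    x = x - x.mean()
    return np.mean(x ** 4) / np.mean(x ** 2) ** 2 - 3.0

stats = {
    'noise only': (skewness(noise_diffs), excess_kurtosis(noise_diffs)),
    'noise + transients': (skewness(diffs), excess_kurtosis(diffs)),
}
```

In this toy setting the few sharp upward jumps barely change the shape of the histogram, but they show up as positive skew and heavy tails in `stats`. The same two numbers could be computed per neuron on the real increments.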

In [31]:
f = net['fluorescence'].value
fig = plt.figure()
writer = FFMpegWriter(fps=8, metadata=metadata, extra_args=['-acodec', 'libfaac', '-vcodec', 'libx264'])
video = "neuron_autocorrelation.mp4"
videopath = "%s/%s" % (nbdir, video)
max_frames = 30
with writer.saving(fig,videopath , video_dpi):
    for i in range(f.shape[0]):
        if (i>max_frames):
            break  # only render the first few neurons
        print "Writing Frame #%d / %d " % (i, f.shape[0])
        fig.clear()  # start each frame from an empty figure
        plt.xlim(-0.3, 0.3)
        plt.ylim(0, 10)
        allf = f[i,1:]-f[i,0:-1]
        n, bins, patches = plt.hist(allf, 50, normed=1, facecolor='g', alpha=0.75)
        plt.title('Neuron %d Autocorrelation' % (i))
        writer.grab_frame()  # append the current figure as one video frame
print "Done"

So all of the neurons show this kind of pseudo-Gaussian random walk behaviour. Now, let's look at some examples of what the actual spikes look like in the fluorescence data.

In [32]:
f = net['fluorescence'].value
fig = plt.figure()
#writer = FFMpegWriter(fps=3, metadata=metadata, extra_args=['-acodec', 'libfaac', '-vcodec', 'libx264', '-qp', '0'])
writer = FFMpegWriter(fps=3, metadata=metadata, extra_args=['-vcodec', 'libx264','-qp', '0'])

video = "neuron_fluorescence.mp4"
videopath = "%s/%s" % (nbdir, video)
max_frames = 30
with writer.saving(fig,videopath , video_dpi):
    for i in range(f.shape[0]):
        if (i>max_frames):
            break  # only render the first few neurons
        print "Writing Frame #%d / %d " % (i, f.shape[0])
        fig.clear()  # start each frame from an empty figure
        plt.ylim(-0.1, 1.1)
        allf = f[i,:500]#-f[i,0:-1]
        plt.plot(np.linspace(0.0, len(allf)*20.0,len(allf)), allf, 'r-')
        plt.xlabel('Time in ms')
        plt.ylabel('Fluorescence level')
        plt.title('Neuron %d Fluorescence' % (i))
        writer.grab_frame()  # append the current figure as one video frame
print "Done"

Looks like we need to smooth the signal. We will use the simple windowed smoothing function from

In [34]:
import numpy

def smooth(x,window_len=11,window='hanning'):
    """smooth the data using a window with requested size.
    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal 
    (with the window size) in both ends so that transient parts are minimized
    in the beginning and end part of the output signal.
    input:
        x: the input signal 
        window_len: the dimension of the smoothing window; should be an odd integer
        window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
            flat window will produce a moving average smoothing.
    output:
        the smoothed signal

    see also: 
    numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
    TODO: the window parameter could be the window itself if an array instead of a string
    NOTE: length(output) != length(input), to correct this: return y[(window_len/2-1):-(window_len/2)] instead of just y.
    """
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len<3:
        return x

    if not window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    # Pad both ends with reflected copies of the signal, as described in the docstring
    s = numpy.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]
    if window == 'flat': #moving average
        w = numpy.ones(window_len, 'd')
    else:
        w = getattr(numpy, window)(window_len)

    # Convolve the padded signal with the window scaled to unit sum
    y = numpy.convolve(w/w.sum(), s, mode='valid')
    return y
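To get a feeling for what windowed smoothing does to timing, here is a small sketch that smooths an ideal step and measures how far the rise gets smeared out. `hanning_smooth` is a hypothetical same-length variant written only for this example (reflect padding via `np.pad`, odd window lengths assumed); it is not the `smooth` function above.

```python
import numpy as np

def hanning_smooth(x, window_len=15):
    """Same-length smoothing with a unit-sum Hanning window (odd window_len assumed)."""
    w = np.hanning(window_len)
    half = window_len // 2
    # Reflect-pad so the 'valid' convolution returns as many samples as the input
    padded = np.pad(x, half, mode='reflect')
    return np.convolve(padded, w / w.sum(), mode='valid')[:len(x)]

# An idealized spike onset: the signal jumps from 0 to 1 exactly at time step 100
step = np.zeros(200)
step[100:] = 1.0
smoothed = hanning_smooth(step, 15)

# First time steps at which the smoothed signal exceeds 10% and 90% of the jump
rise_start = int(np.argmax(smoothed > 0.1))
rise_end = int(np.argmax(smoothed > 0.9))
rise_width_ms = (rise_end - rise_start) * 20.0  # 20 ms per time step, as in the plots
```

The raw step crosses both levels at the same time step, while the smoothed version starts rising before the true onset and only reaches the upper level several steps later. The wider the window, the larger `rise_width_ms` becomes.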

In [37]:
f = net['fluorescence'].value
fig = plt.figure()
writer = FFMpegWriter(fps=3, metadata=metadata, extra_args=['-acodec', 'libfaac', '-vcodec', 'libx264'])

video = "neuron_smoothed_fluorescence.mp4"
videopath = "%s/%s" % (nbdir, video)
max_frames = 30
with writer.saving(fig,videopath , video_dpi):
    for i in range(f.shape[0]):
        if (i>max_frames):
            break  # only render the first few neurons
        print "Writing Frame #%d / %d " % (i, f.shape[0])
        fig.clear()  # start each frame from an empty figure
        allf = smooth(f[i,:500], 16)#-f[i,0:-1]
        plt.plot(np.linspace(0.0, len(allf)*20.0,len(allf)), allf, 'r-')
        plt.xlabel('Time in ms')
        plt.ylabel('Fluorescence level')
        plt.title('Neuron %d Fluorescence' % (i))
        writer.grab_frame()  # append the current figure as one video frame
print "Done"

Looks much cleaner. But we risk losing high-frequency and exact timing information. We are going to take a look at wavelet analysis using the "PyWavelets" library to see if we can get time-localized signals in different frequency bands, in order to separate different types of signals.

There is example code for multilevel decomposition available at

First of all, we import the pywt package and select one of the wavelets. The wavelet browser at allows to select a suitable one nicely.

In [16]:
import pywt

haar = pywt.Wavelet('haar')
w1 = pywt.Wavelet('rbio1.3')
w2 = pywt.Wavelet('rbio2.2')

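# Maximum useful decomposition level depends on the number of time steps (last axis), not the number of neurons
print pywt.dwt_max_level(f.shape[-1], w2)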


These functions effectively implement low-pass (rough_coeffs) and high-pass (fine_coeffs) filters on the coefficients of a multi-level wavelet decomposition.

In [17]:
def rough_coeffs(coeff, max_level):
    results = list(coeff)
    for i in range(max_level+1,len(coeff)):
        results[i] = np.zeros_like(coeff[i])
    return results

def fine_coeffs(coeff, min_level):
    results = list(coeff)
    for i in range(0, min_level):
        results[i] = np.zeros_like(coeff[i])
    return results
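To make the effect of zeroing coefficient arrays concrete without depending on PyWavelets, here is a hand-written Haar transform in plain NumPy. It is only a sketch: `haar_wavedec`, `haar_waverec` and `keep_coarse` are hypothetical helpers, the coefficient order (coarse to fine) is chosen to mirror the list layout that `rough_coeffs` assumes, and sign conventions may differ from the library's.

```python
import numpy as np

def haar_wavedec(x, levels):
    """Multi-level orthonormal Haar decomposition.

    Returns [cA_n, cD_n, ..., cD_1], ordered coarse to fine.
    len(x) must be divisible by 2**levels.
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2.0))  # local differences
        approx = (even + odd) / np.sqrt(2.0)         # local averages
    return [approx] + details[::-1]

def haar_waverec(coeffs):
    """Inverse of haar_wavedec."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        out = np.empty(2 * len(approx))
        out[0::2] = (approx + detail) / np.sqrt(2.0)
        out[1::2] = (approx - detail) / np.sqrt(2.0)
        approx = out
    return approx

def keep_coarse(coeffs, max_level):
    """Same idea as rough_coeffs: zero every array after index max_level."""
    return [c if i <= max_level else np.zeros_like(c)
            for i, c in enumerate(coeffs)]

rng = np.random.RandomState(1)
signal = np.where(np.arange(256) >= 128, 1.0, 0.0) + rng.normal(0.0, 0.05, 256)

coeffs = haar_wavedec(signal, 4)
full = haar_waverec(coeffs)                     # all coefficients: exact round trip
lowpass = haar_waverec(keep_coarse(coeffs, 1))  # keep cA_4 and cD_4 only
```

With the three finest detail arrays set to zero, `lowpass` is constant over blocks of 8 samples, which is the Haar version of the low-pass behaviour described above. Smoother wavelets such as the rbio family avoid the blocky steps.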

Now, let's do a multilevel wavelet decomposition, and a corresponding reconstruction with all the high-frequency coefficients set to zero.

In [18]:
f = net['fluorescence'].value
data = f[99,:]
haarcoeffs = pywt.wavedec(data, haar)
w1coeffs = pywt.wavedec(data, w1)
w2coeffs = pywt.wavedec(data, w2)

fig = plt.figure(figsize=(12,8))
allf = smooth(data[:500], 28)

plt.plot(np.linspace(0.0, len(allf)*20.0,len(allf)), allf, 'r-', alpha=0.6, label='Smooth 28')

allf = smooth(data[:500], 5)

#plt.plot(np.linspace(0.0, len(allf)*20.0,len(allf)), allf, 'r-', alpha=0.2, title='Smooth 5')
#for i in range(8,len(coeffs)):

# Reconstruct from the coarse coefficients only and keep the first 500 samples,
# so that the curves cover the same interval as the smoothed signal above
rf = pywt.waverec(rough_coeffs(haarcoeffs, 11), haar)[:500]
plt.plot(np.linspace(0.0, len(rf)*20.0,len(rf)), rf, 'g-', alpha=0.7, label='Haar')

rf = pywt.waverec(rough_coeffs(w1coeffs, 11), w1)[:500]
plt.plot(np.linspace(0.0, len(rf)*20.0,len(rf)), rf, 'b-', alpha=0.7, label='rbio1.3')

rf = pywt.waverec(rough_coeffs(w2coeffs,11), w2)[:500]
plt.plot(np.linspace(0.0, len(rf)*20.0,len(rf)), rf, 'y-', alpha=0.7, label='rbio2.2')

plt.xlabel('Time in ms')
plt.ylabel('Fluorescence level')
plt.title('Neuron %d Fluorescence' % (99))  # data was taken from neuron 99
plt.legend()

This is a much sparser representation and allows us to compress the data while keeping most of the information. The "rbio2.2" wavelet looks interesting, but I think we would need a custom wavelet to get really good results.

Windowed Up/Down Volatility Metrics

Let's try something else. We can simply slide a window along the signal and record for every data point how much above the minimum it is within the sliding window.
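As a slow but easy-to-check reference for this idea, the up-span can be written in a few lines of NumPy. This is a sketch under the assumption that the window is trailing, i.e. it covers the current sample and the `windowsize - 1` samples before it; `upspan_reference` is a hypothetical name used only here.

```python
import numpy as np

def upspan_reference(src, windowsize):
    """Current value minus the minimum over a trailing window (current sample included)."""
    src = np.asarray(src, dtype=np.float64)
    result = np.empty_like(src)
    for i in range(len(src)):
        start = max(0, i - windowsize + 1)  # window is shorter at the very beginning
        result[i] = src[i] - src[start:i + 1].min()
    return result

trace = np.array([0.2, 0.1, 0.1, 0.5, 0.4, 0.3, 0.3, 0.3])
up = upspan_reference(trace, 3)
```

For this toy trace `up` is zero everywhere except at the jump (0.4) and the step right after it (0.3). The metric reacts to a sudden rise and drops back to zero once the low values have left the window. The volatility (max - min) and down-span (max - current) follow from the same loop.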

We are going to implement these in optimized Cython code. See Cython Numpy Tutorial and this Short IPython cythonmagic Tutorial for an explanation of the technical aspects of writing Cython code dealing with Numpy arrays within an IPython Notebook.

In [19]:
%%cython
cimport cython
cimport numpy as np
import numpy as np
from libc.math cimport exp, sqrt, pow, log, erf

DTYPE = np.float32
ctypedef np.float32_t DTYPE_t

def volatility(np.ndarray[DTYPE_t, ndim=1] src,int windowsize):
    """
    Calculate the windowed volatility (max-min values) of a numpy array
        src: numpy array with dtype=numpy.float32 and one dimension
        windowsize: integer giving the size of the sliding window
    returns: numpy array of same shape and type as src
    """
    assert src.dtype == DTYPE
    assert windowsize>0
    cdef int length = src.shape[0]
    assert length>windowsize
    cdef DTYPE_t cmax = src[0]
    cdef DTYPE_t cmin = src[0]
    cdef np.ndarray result = np.zeros([length], dtype=DTYPE)
    cdef np.ndarray rotwindow = np.zeros([windowsize], dtype=DTYPE)
    cdef DTYPE_t cval = src[0]
    rotwindow[:] = cval
    cmax = cval
    cmin = cval
    cdef int rotpos = 0
    cdef int i
    cdef int remax = 0
    cdef int remin = 0
    for i in range(length):
        cval = src[i]
        # Does the current min/max drop out of the rotating window ? If so, remember that
        if (rotwindow[rotpos]==cmax):
            remax = 1
        if (rotwindow[rotpos]==cmin):
            remin = 1
        # Write current source value into rotating window
        rotwindow[rotpos] = cval
        # And increment window position
        rotpos = (rotpos + 1) % windowsize
        if (remax==1):
            cmax = np.max(rotwindow)
            if (cval>cmax):
                cmax = cval
        if (remin==1):
            cmin = np.min(rotwindow)
            if (cval<cmin):
                cmin = cval
        result[i] = cmax-cmin
    return result

def upspan(np.ndarray[DTYPE_t, ndim=1] src,int windowsize):
    """
    Calculate the windowed up-span (current - min value) of a numpy array
        src: numpy array with dtype=numpy.float32 and one dimension
        windowsize: integer giving the size of the sliding window
    returns: numpy array of same shape and type as src
    """
    assert src.dtype == DTYPE
    assert windowsize>0
    cdef int length = src.shape[0]
    assert length>windowsize
    cdef DTYPE_t cmax = src[0]
    cdef DTYPE_t cmin = src[0]
    cdef np.ndarray result = np.zeros([length], dtype=DTYPE)
    cdef np.ndarray rotwindow = np.zeros([windowsize], dtype=DTYPE)
    cdef DTYPE_t cval = src[0]
    rotwindow[:] = cval
    cmax = cval
    cmin = cval
    cdef int rotpos = 0
    cdef int i
    cdef int remax = 0
    cdef int remin = 0
    for i in range(length):
        cval = src[i]
        # Does the current min/max drop out of the rotating window ? If so, remember that
        if (rotwindow[rotpos]==cmax):
            remax = 1
        if (rotwindow[rotpos]==cmin):
            remin = 1
        # Write current source value into rotating window
        rotwindow[rotpos] = cval
        # And increment window position
        rotpos = (rotpos + 1) % windowsize
        if (remax==1):
            cmax = np.max(rotwindow)
            if (cval>cmax):
                cmax = cval
        if (remin==1):
            cmin = np.min(rotwindow)
            if (cval<cmin):
                cmin = cval
        result[i] = cval-cmin
    return result

def downspan(np.ndarray[DTYPE_t, ndim=1] src,int windowsize):
    """
    Calculate the windowed down-span (max - current value) of a numpy array
        src: numpy array with dtype=numpy.float32 and one dimension
        windowsize: integer giving the size of the sliding window
    returns: numpy array of same shape and type as src
    """
    assert src.dtype == DTYPE
    assert windowsize>0
    cdef int length = src.shape[0]
    assert length>windowsize
    cdef DTYPE_t cmax = src[0]
    cdef DTYPE_t cmin = src[0]
    cdef np.ndarray result = np.zeros([length], dtype=DTYPE)
    cdef np.ndarray rotwindow = np.zeros([windowsize], dtype=DTYPE)
    cdef DTYPE_t cval = src[0]
    rotwindow[:] = cval
    cmax = cval
    cmin = cval
    cdef int rotpos = 0
    cdef int i
    cdef int remax = 0
    cdef int remin = 0
    for i in range(length):
        cval = src[i]
        # Does the current min/max drop out of the rotating window ? If so, remember that
        if (rotwindow[rotpos]==cmax):
            remax = 1
        if (rotwindow[rotpos]==cmin):
            remin = 1
        # Write current source value into rotating window
        rotwindow[rotpos] = cval
        # And increment window position
        rotpos = (rotpos + 1) % windowsize
        if (remax==1):
            cmax = np.max(rotwindow)
            if (cval>cmax):
                cmax = cval
        if (remin==1):
            cmin = np.min(rotwindow)
            if (cval<cmin):
                cmin = cval
        result[i] = cmax-cval
    return result

def spikedetect(np.ndarray[DTYPE_t, ndim=1] src, DTYPE_t threshold=0.38,  DTYPE_t span_threshold=0.29, 
                DTYPE_t span_increase_threshold=0.18, 
                int no_repeat_size=40, int windowsize=30):
    """
    Spike detection function using a sliding window. It works by using multiple thresholds on
    absolute value of fluorescence as well as thresholds on upspan and increase of upspan within two timesteps.
    It also includes the functionality to block spike-reporting for a configurable number of timesteps after a
    detected spike.
    All of this in one optimized function. Thresholds and window-size are configurable. 
        src: Source numpy array with dtype=numpy.float32 and one dimension
        threshold: float threshold: Minimum absolute value to require for a spike
        span_threshold: float threshold for the minimum up-span we require
        span_increase_threshold: float threshold for the minimum increase in up-span we require
        no_repeat_size: int minimum number of time steps after a spike during which we won't report another spike
        windowsize: int size of the sliding window
    returns: numpy array of same shape and type as src. Spikes are reported as 1.0 values.
    """
    assert src.dtype == DTYPE
    assert windowsize>0
    cdef int length = src.shape[0]
    assert length>windowsize
    assert no_repeat_size>=0
    cdef DTYPE_t cmax = src[0]
    cdef DTYPE_t cmin = src[0]
    cdef DTYPE_t upspan = 0.0
    cdef DTYPE_t upspan_increase = 0.0
    cdef np.ndarray result = np.zeros([length], dtype=DTYPE)
    cdef np.ndarray rotwindow = np.zeros([windowsize], dtype=DTYPE)
    cdef DTYPE_t cval = src[0]
    rotwindow[:] = cval
    cmax = cval
    cmin = cval
    cdef int rotpos = 0
    cdef int i
    cdef int remax = 0
    cdef int remin = 0
    cdef int lastspike = -1
    for i in range(length):
        cval = src[i]
        # Does the current min/max drop out of the rotating window ? If so, remember that
        if (rotwindow[rotpos]==cmax):
            remax = 1
        if (rotwindow[rotpos]==cmin):
            remin = 1
        # Write current source value into rotating window
        rotwindow[rotpos] = cval
        # And increment window position
        rotpos = (rotpos + 1) % windowsize
        if (remax==1):
            cmax = np.max(rotwindow)
            if (cval>cmax):
                cmax = cval
        if (remin==1):
            cmin = np.min(rotwindow)
            if (cval<cmin):
                cmin = cval
        last_upspan_increase = upspan_increase
        upspan_increase = cval-cmin-upspan
        upspan = cval-cmin
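        if  (
                    (upspan_increase>=span_increase_threshold or last_upspan_increase>=span_increase_threshold)
                and cval>=threshold
                and upspan>=span_threshold  # assumed: the docstring lists this check and span_threshold is otherwise unused
            ):
            if (lastspike<0 or ((i-lastspike)>no_repeat_size)):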
                result[i] = 1.0
                lastspike = i
    return result
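For checking the logic by hand, here is a plain Python/NumPy sketch of the rule described in the docstring of spikedetect. It assumes that all three thresholds (absolute level, up-span, and up-span increase in the current or previous step) must hold at once and that the window is trailing. `spikedetect_reference` is a hypothetical name, and the synthetic trace below is not competition data.

```python
import numpy as np

def spikedetect_reference(src, threshold=0.38, span_threshold=0.29,
                          span_increase_threshold=0.18,
                          no_repeat_size=40, windowsize=30):
    """Slow reference of the thresholding rule; returns 1.0 where a spike is reported."""
    src = np.asarray(src, dtype=np.float64)
    result = np.zeros_like(src)
    upspan = 0.0
    upspan_increase = 0.0
    lastspike = -1
    for i in range(len(src)):
        cval = src[i]
        # Minimum over the trailing window, current sample included
        cmin = src[max(0, i - windowsize + 1):i + 1].min()
        last_upspan_increase = upspan_increase
        upspan_increase = (cval - cmin) - upspan
        upspan = cval - cmin
        rising = (upspan_increase >= span_increase_threshold
                  or last_upspan_increase >= span_increase_threshold)
        # Assumption: all three conditions are combined with AND
        if rising and cval >= threshold and upspan >= span_threshold:
            # Refractory period: suppress reports shortly after the last spike
            if lastspike < 0 or (i - lastspike) > no_repeat_size:
                result[i] = 1.0
                lastspike = i
    return result

# Noise-free toy trace: baseline 0.1 with two decaying transients at steps 50 and 120
t = np.arange(200)
trace = np.full(200, 0.1)
for onset in (50, 120):
    trace[onset:] += 0.5 * np.exp(-(t[onset:] - onset) / 15.0)

detected = np.flatnonzero(spikedetect_reference(trace))
```

On this trace the sketch reports exactly the two onsets. The step after each onset also satisfies the thresholds (through the previous-step increase) but is suppressed by the no-repeat window, which is the reason for having that parameter.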

Visualizing Spike Detection

Now that we have a spike detection function, let's see how it performs on various networks. We take 200 random sample windows from most of the networks we have, including the validation and test networks, to check that the detector works reliably on known signals.

I used this visualization to iteratively tune the default threshold parameters etc. of the spikedetect(...) function above until it performed well.

In [20]:
def visualize_spike_detector(network, numframes=200, dpi=100):

    #spikedata, reconstruction = detect_spikes(store, network, 0.76)
    f = store[network]['fluorescence'].value
    fsize = (12,8)
    fig = plt.figure(figsize=fsize)
    writer = FFMpegWriter(fps=1, metadata=metadata, extra_args=['-acodec', 'libfaac', '-vcodec', 'libx264'])
    width = 500
    video = "neuron_spikes_%s.mp4" % (network)
    videopath = "%s/%s" % (nbdir, video)
    with writer.saving(fig,videopath , dpi):
        for i in range(numframes):
            neuron = np.random.randint(0, f.shape[0])
            start = np.random.randint(0, f.shape[1]-width)
            if (i == 0):
                # This is a particularly interesting case in normal-4 and normal-4-lownoise
                neuron = 34
                start = 82330 
            s = slice(start, start+width)
            print "%s: Writing Frame #%d / %d " % (network, i, numframes)
            fig.clear()  # start each frame from an empty figure
            plt.ylim(-0.2, 1.5)
            rdata = f[neuron,s]
            allf = smooth(rdata, 16)
            vol30 = volatility(rdata, 30)
            up30 = upspan(rdata, 30)
            down30 = downspan(rdata, 30)
            down30dt = down30[1:]-down30[:-1]
            up30dt = up30[1:]-up30[:-1]
            spiked = spikedetect(rdata)*1.2
            plt.plot(np.linspace(0.0, len(allf),len(allf)), allf, 'c-', label='Fluorescence (S16)', alpha=0.7)
            plt.plot(np.linspace(0.0, len(up30),len(up30)), up30, 'b-', label='Upspan', alpha=0.7)
            plt.plot(np.linspace(0.0, len(up30dt),len(up30dt)), up30dt, 'g-', label='Upspan/dt', alpha=0.7)
            plt.plot(np.linspace(0.0, len(spiked),len(spiked)), spiked, 'r-', label='Detected Spikes', alpha=0.8)
            plt.xlabel('Time step')
            plt.ylabel('Fluorescence level')
            plt.title('%s #%d t[%d to %d]' % (network, i, start, start+width))
            plt.legend()  # show the labels set in the plot calls above
            writer.grab_frame()  # append the current figure as one video frame
    print "Done"
    return videopath

In [22]:
display_video_inline(visualize_spike_detector('normal-4', 30))

<matplotlib.figure.Figure at 0x64a5750>

This looks good. We can still tune it later on if we need even better accuracy. For now it does as well as I do. And of course I checked more than 30 frames and more than one of the networks ;)

In [31]:
# Now, let's do the spike detection and save the results in our store

from sys import stdout
for nwname in store.keys():
    net = store[nwname]
    if ('spikes' in net.keys()):
        del net['spikes']  # drop stale results so that create_dataset below does not fail
    fluor = net['fluorescence'].value
    spikes = np.zeros_like(fluor, dtype=np.float32)
    ncount = fluor.shape[0]
    for n in range(ncount):
       print "Network: %s: %d / %d" % (nwname, n, ncount)
       spikes[n,:] = spikedetect(fluor[n,:])
    net.create_dataset("spikes", data=spikes)

Network: valid: 999 / 1000

The detection seems pretty good. We might have to tune the parameters a bit for weak spikes, where I am in doubt whether something is an actual spike or just the spike of a nearby neuron.

How do the pros do it?

See the Scholarpedia article on Spike Sorting, which explains spike detection and shape classification as they are usually done.

Why do we do it this way instead?

We are interested in very accurate timing information. When exactly a spike is triggered can make a huge difference for inferring causality.